import os
import json
import time
from colorama import init, Fore, Style
from groq import Groq
import requests

from ..utils.api_utils import load_prompt_options, get_prompt_content
from ..utils.env_manager import ensure_env_file, get_api_key

# Initialize colorama once, at module import time; its Fore/Style codes are used
# below to color the node's console error messages.
init()

class GroqAPIALMTranscribe:
    """ComfyUI node that transcribes a local audio file through the Groq API.

    The audio file is uploaded to the Groq transcription endpoint and the
    response is returned as ``(transcription_result, success, status_code)``,
    optionally post-processed into plain lines or timestamped lines.
    """

    DEFAULT_PROMPT = "Transcribe the audio file"

    # Supported models for transcription
    TRANSCRIPTION_MODELS = [
        "whisper-large-v3-turbo",
        "distil-whisper-large-v3-en",
        "whisper-large-v3",
    ]

    SUPPORTED_AUDIO_FORMATS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']

    CLASS_TYPE = "text"  # Necessary for node recognition

    # Groq audio transcription endpoint.
    API_URL = 'https://api.groq.com/openai/v1/audio/transcriptions'
    # (connect, read) timeouts in seconds passed to requests.post; without a
    # timeout a stalled connection would block the node forever.
    REQUEST_TIMEOUT = (30, 300)
    # Seconds to wait between request attempts.
    RETRY_DELAY = 2
    # HTTP statuses treated as transient (rate limit / server side) and retried.
    RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})

    def __init__(self):
        # Load the API key through the project's env helpers.
        ensure_env_file()
        self.api_key = get_api_key()
        # NOTE: this class sends requests with `requests.post` and never uses
        # `self.client`; it is kept in case other code relies on the attribute.
        self.client = Groq(api_key=self.api_key)

        # Load prompt options (default presets + user presets)
        self.prompt_options = load_prompt_options(self._prompt_files())

    @staticmethod
    def _prompt_files():
        """Return the default and user prompt JSON paths in the ``groq`` folder next to this module."""
        current_directory = os.path.dirname(os.path.realpath(__file__))
        groq_directory = os.path.join(current_directory, 'groq')
        return [
            os.path.join(groq_directory, 'DefaultPrompts_ALM_Transcribe.json'),
            os.path.join(groq_directory, 'UserPrompts_ALM_Transcribe.json'),
        ]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for ComfyUI; preset names come from the prompt JSON files."""
        try:
            prompt_options = load_prompt_options(cls._prompt_files())
        except Exception as e:
            # Best effort: the node stays usable with only the default preset.
            print(Fore.RED + f"Failed to load prompt options: {e}" + Style.RESET_ALL)
            prompt_options = {}

        return {
            "required": {
                "model": (cls.TRANSCRIPTION_MODELS, {"tooltip": "Select the transcription model to use."}),
                "file_path": ("STRING", {"label": "Audio file path", "multiline": False, "default": "", "tooltip": "Path to the audio file to be transcribed."}),
                "preset": ([cls.DEFAULT_PROMPT] + list(prompt_options.keys()), {"tooltip": "Select a preset for the transcription or custom prompts."}),
                "user_input": ("STRING", {"label": "User Input (for prompt)", "multiline": True, "default": "", "tooltip": "Optional user input to guide the transcription."}),
                "response_format": (["text", "text_with_linebreaks", "text_with_timestamps", "json", "verbose_json"], {"tooltip": "Format in which the transcription response is returned.\n\nText: Only the text, in one text chunk.\n\ntext_with_linebreaks: Only the text, with each line separated by a line break.\n\ntext_with_timestamps: Only the text, with each timestamp is separated by a line break.\n\njson: The JSON response from the API.\n\nverbose_json: The JSON response from the API with more details."}),
                "temperature": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Controls randomness in responses.\n\nA higher temperature makes the model take more risks, leading to more creative or varied answers.\n\nA lower temperature (closer to 0.1) makes the model more focused and predictable."}),
                "language": ("STRING", {"label": "Language (ISO 639-1 code, e.g., 'en', 'fr')", "default": "en", "multiline": False, "tooltip": "Language of the audio file in ISO 639-1 code.\nhttps://www.wikiwand.com/en/articles/List_of_ISO_639_language_codes\n\nis tg uz zh ru tr hi la tk haw fr vi cs hu kk he cy bs sw ht mn gl si mg sa es ja pt lt mr fa sl kn uk ms ta hr bg pa yi fo th lv ln ca br sq jv sn gu ba te bn et sd tl ha de hy so oc nn az km yo ko pl da mi ml ka am tt su yue nl no ne mt my ur ps ar id fi el ro as en it sk be lo lb bo sv sr mk eu"}),
                "max_retries": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1, "tooltip": "Maximum number of retries in case of transcription failures."}),
            }
        }

    RETURN_TYPES = ("STRING", "BOOLEAN", "STRING")
    RETURN_NAMES = ("transcription_result", "success", "status_code")
    OUTPUT_TOOLTIPS = ("The API response. This is the transcription generated by the model", "Whether the request was successful", "The status code of the request")

    FUNCTION = "process_transcription_request"
    CATEGORY = "⚡ MNeMiC Nodes"
    DESCRIPTION = "Uses Groq API to transcribe audio."
    OUTPUT_NODE = True

    def _build_prompt(self, preset, user_input):
        """Return the prompt text to send (possibly empty) for ``preset`` and ``user_input``.

        For the default preset the trimmed user input itself is the prompt.
        For any other preset, ``[user_input]`` in the preset template is
        replaced with the trimmed user input (an empty string when none is
        given, so the placeholder is never sent literally).
        """
        user_text = user_input.strip() if user_input else ''
        if preset == self.DEFAULT_PROMPT:
            return user_text
        # `or ''` guards against the helper returning None for an unknown preset.
        prompt_template = get_prompt_content(self.prompt_options, preset) or ''
        return prompt_template.replace('[user_input]', user_text)

    @staticmethod
    def _format_timestamp(seconds):
        """Format an offset in seconds as ``[MM:SS.mmm]``.

        Works in whole milliseconds to avoid float truncation, e.g. 1.001 s
        gives ``[00:01.001]``.
        """
        total_ms = int(round(float(seconds) * 1000))
        minutes, remainder_ms = divmod(total_ms, 60_000)
        secs, millis = divmod(remainder_ms, 1000)
        return f"[{minutes:02d}:{secs:02d}.{millis:03d}]"

    @classmethod
    def _segments_to_text(cls, segments, with_timestamps):
        """Join segment texts one per line, optionally prefixed with their start timestamp.

        Non-dict entries are skipped; a missing/null ``start`` is treated as 0.
        """
        lines = []
        for segment in segments:
            if not isinstance(segment, dict):
                continue
            text = (segment.get('text') or '').strip()
            if with_timestamps:
                text = f"{cls._format_timestamp(segment.get('start') or 0)}{text}"
            lines.append(text)
        return "\n".join(lines).strip()

    def _format_response(self, response_text, response_format, api_response_format):
        """Turn a successful (HTTP 200) response body into the node's output tuple."""
        if api_response_format == 'text':
            # Plain text is returned exactly as the API sent it.
            return response_text, True, "200 OK"

        try:
            response_json = json.loads(response_text)
        except ValueError as e:
            print(Fore.RED + f"Error parsing JSON response: {str(e)}" + Style.RESET_ALL)
            return "Error parsing JSON response.", False, "200 OK but failed to parse JSON"

        if response_format in ('json', 'verbose_json'):
            # ensure_ascii=False keeps non-Latin transcripts readable instead of \u-escaped.
            return json.dumps(response_json, indent=4, ensure_ascii=False), True, "200 OK"

        # text_with_timestamps / text_with_linebreaks are built from the
        # `segments` list of the verbose_json response.
        segments = response_json.get('segments') if isinstance(response_json, dict) else None
        text = self._segments_to_text(segments or [], response_format == 'text_with_timestamps')
        return text, True, "200 OK"

    def process_transcription_request(self, model, file_path, preset, user_input, response_format, temperature, language, max_retries):
        """Transcribe ``file_path`` with the Groq transcription endpoint.

        Args:
            model: One of ``TRANSCRIPTION_MODELS``.
            file_path: Path to a local audio file whose extension is in
                ``SUPPORTED_AUDIO_FORMATS``.
            preset: ``DEFAULT_PROMPT`` or a key of the loaded prompt options.
            user_input: Optional text. For the default preset it is the prompt;
                otherwise it replaces ``[user_input]`` in the preset template.
            response_format: One of the formats offered in ``INPUT_TYPES``.
            temperature: Sampling temperature sent to the API.
            language: ISO 639-1 code; ``'en'`` is used when empty.
            max_retries: Total number of request attempts (values below 1 are
                treated as 1). Only network errors and statuses in
                ``RETRYABLE_STATUS_CODES`` are retried.

        Returns:
            ``(transcription_result, success, status_code)`` as declared in
            ``RETURN_TYPES`` / ``RETURN_NAMES``.
        """
        # Validate file path
        if not os.path.isfile(file_path):
            print(Fore.RED + f"Error: File not found at path {file_path}" + Style.RESET_ALL)
            return "File not found.", False, "400 Bad Request"

        # Validate file extension. splitext only inspects the last path
        # component, so dots in directory names are not mistaken for it.
        file_extension = os.path.splitext(file_path)[1].lstrip('.').lower()
        if file_extension not in self.SUPPORTED_AUDIO_FORMATS:
            print(Fore.RED + f"Error: Unsupported audio format '{file_extension}'. Supported formats are: {', '.join(self.SUPPORTED_AUDIO_FORMATS)}" + Style.RESET_ALL)
            return f"Unsupported audio format '{file_extension}'.", False, "400 Bad Request"

        # Read the file into memory as bytes so the same payload can be re-sent
        # on every attempt (a file object would be exhausted after one upload).
        try:
            with open(file_path, 'rb') as audio_file:
                audio_data = audio_file.read()
        except OSError as e:
            print(Fore.RED + f"Error reading audio file: {str(e)}" + Style.RESET_ALL)
            return "Error reading audio file.", False, "400 Bad Request"
        print("Audio file loaded successfully.")

        prompt = self._build_prompt(preset, user_input)
        print(f"Using prompt: {prompt}")
        # NOTE(review): the previous code noted a 224-token prompt limit but
        # never enforced it; prompts are sent untruncated — confirm API behavior.

        # The line-based formats are derived from the segments in verbose_json.
        if response_format in ('json', 'verbose_json', 'text'):
            api_response_format = response_format
        elif response_format in ('text_with_timestamps', 'text_with_linebreaks'):
            api_response_format = 'verbose_json'
        else:
            print(Fore.RED + "Unknown response format selected." + Style.RESET_ALL)
            return "Unknown response format.", False, "400 Bad Request"

        headers = {'Authorization': f'Bearer {self.api_key}'}
        files = {'file': (os.path.basename(file_path), audio_data)}
        data = {
            'model': model,
            'response_format': api_response_format,
            'temperature': str(temperature),
            'language': (language or '').strip() or 'en',  # Default to 'en' if not specified
        }
        if prompt:
            data['prompt'] = prompt

        # Never log `headers`: it holds the API key in plain text.
        print(f"Sending request to {self.API_URL} with data: {data}")

        attempts = max(1, int(max_retries))
        for attempt in range(1, attempts + 1):
            print(f"Attempt {attempt} of {attempts}")
            try:
                response = requests.post(
                    self.API_URL,
                    headers=headers,
                    data=data,
                    files=files,
                    timeout=self.REQUEST_TIMEOUT,
                )
            except requests.RequestException as e:
                # Network-level failure (connection error, timeout, ...): retry.
                print(Fore.RED + f"Request failed: {str(e)}" + Style.RESET_ALL)
            else:
                print(f"Response status: {response.status_code}")
                if response.status_code == 200:
                    print("Request successful.")
                    return self._format_response(response.text, response_format, api_response_format)

                status = f"{response.status_code} {response.reason}"
                print(Fore.RED + f"Error: {status}" + Style.RESET_ALL)
                print(f"Response body: {response.text}")
                # Non-transient errors, and a transient one on the last
                # attempt, are reported to the caller with the API's body.
                if response.status_code not in self.RETRYABLE_STATUS_CODES or attempt == attempts:
                    return response.text, False, status

            # No point waiting after the final attempt.
            if attempt < attempts:
                time.sleep(self.RETRY_DELAY)

        print(Fore.RED + "Failed after all retries." + Style.RESET_ALL)
        return "Failed after all retries.", False, "Failed after all retries"
